package clear.experiment;

import java.io.File;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.regex.Pattern;

import clear.parse.VoiceDetector;
import clear.propbank.PBArg;
import clear.propbank.PBInstance;
import clear.propbank.PBLoc;
import clear.propbank.PBReader;
import clear.treebank.TBNode;
import clear.treebank.TBReader;
import clear.treebank.TBTree;
import clear.util.IOUtil;
import clear.util.tuple.JObjectDoubleTuple;
import clear.util.tuple.JObjectIntTuple;

import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;

/**
 * Experimental corpus statistics over PropBank verb instances and their Treebank parses.
 *
 * <p>The constructor reads every {@code <corpus>/prop/*.prop} file under a root directory together
 * with the same-named {@code <corpus>/parse/*.parse} file. It passes each qualifying verb instance
 * to the enabled analysis and then writes that analysis' report. Which analysis runs is selected by
 * (un)commenting the calls in the constructor, {@link #init()}, {@link #readProps(String)} and
 * {@link #print(String)}:
 * <ul>
 *   <li><b>NumberedAdjunct</b>: prepositions heading {@code PP} locations of {@code ARG0}..{@code ARG5}.</li>
 *   <li><b>VerbPrepPMI</b>: log-ratio of how often a verb/preposition pair occurs in numbered vs.
 *       non-numbered arguments of that verb.</li>
 *   <li><b>RequiredArgument</b>: per-verb scores for each argument label, restricted to non-passive
 *       predicates whose sentence group is {@code SQ}.</li>
 * </ul>
 */
public class AnalyzePBArgs
{
	/** Reserved key used to store running totals in the count maps. */
	static final String TOTAL = "TOTAL";

	/** Extension of PropBank files; other files in a {@code prop} directory are ignored. */
	private static final String  PROP_EXT     = ".prop";
	/** Matches numbered argument labels ({@code ARG0}..{@code ARG9}). */
	private static final Pattern NUMBERED_ARG = Pattern.compile("ARG\\d");
	/** Matches the modal and negation labels, which VerbPrepPMI and RequiredArgument ignore. */
	private static final Pattern MOD_OR_NEG   = Pattern.compile("ARGM-MOD|ARGM-NEG");

	/** Trees of the current parse file, indexed by {@code PBInstance.treeIndex}. */
	ArrayList<TBTree> ls_trees;
	/** Tree of the instance currently being processed. */
	TBTree            tb_tree;

	/** NumberedAdjunct: preposition counts for ARG0..ARG5 (list index = argument number). */
	ArrayList<ObjectIntOpenHashMap<String>> ls_numberedAdjuncts;
	/** Verb lemma -> number of processed instances, plus {@link #TOTAL}. */
	ObjectIntOpenHashMap<String> m_verbs;
	/** VerbPrepPMI: verb lemma -> number of counted arguments, plus {@link #TOTAL}. */
	ObjectIntOpenHashMap<String> m_verbArgs;
	/** Numbered-argument counts, keyed by verb lemma (VerbPrepPMI) or {@code verb_label} (RequiredArgument). */
	ObjectIntOpenHashMap<String> m_verbArgNs;
	/** Non-numbered-argument counts, keyed by verb lemma (VerbPrepPMI) or {@code verb_label} (RequiredArgument). */
	ObjectIntOpenHashMap<String> m_verbArgMs;
	/** Preposition counts (VerbPrepPMI) or non-numbered label counts (RequiredArgument), plus {@link #TOTAL}. */
	ObjectIntOpenHashMap<String> m_preps;
	/** VerbPrepPMI: {@code verb_prep} -> co-occurrence count over all counted arguments. */
	ObjectIntOpenHashMap<String> m_verbPreps;
	/** VerbPrepPMI: {@code verb_prep} -> co-occurrence count in numbered arguments. */
	ObjectIntOpenHashMap<String> m_verbPrepNs;
	/** VerbPrepPMI: {@code verb_prep} -> co-occurrence count in non-numbered arguments. */
	ObjectIntOpenHashMap<String> m_verbPrepMs;

	/** RequiredArgument: number of instances that passed the voice/sentence-group filter. */
	int n_count = 0;

	/**
	 * Runs the enabled analysis over the corpus and writes its report.
	 *
	 * @param rootPath   corpus root whose subdirectories contain {@code prop} and {@code parse} directories
	 * @param outputFile path of the report to write
	 */
	public AnalyzePBArgs(String rootPath, String outputFile)
	{
		init();
		read(rootPath);
	//	trimVerbPrep();
		print(outputFile);
	}

	/** Allocates the count maps used by the enabled analyses. */
	void init()
	{
	//	initNumberedAdjunct();
		initVerbPrepPMI();
		initRequiredArgument();
	}

	/** Writes the report of the enabled analysis to {@code outputFile}. */
	void print(String outputFile)
	{
	//	printNumberedAdjunct(outputFile);
	//	printVerbPrepPMI(outputFile);
		printRequiredArgument(outputFile);
	}

	/**
	 * Reads every PropBank file under {@code rootPath} together with its parse file.
	 *
	 * <p>Each subdirectory of {@code rootPath} (e.g. {@code ebc}, {@code wsj}) that contains a
	 * {@code prop} directory is processed. For {@code <corpus>/prop/X.prop} the trees are read
	 * from {@code <corpus>/parse/X.parse}. Progress is printed as one dot per 100 files.
	 *
	 * @param rootPath corpus root directory (e.g. OntoNotes v4.0)
	 * @throws IllegalArgumentException if {@code rootPath} cannot be listed as a directory
	 */
	void read(String rootPath)
	{
		String[] corpusDirs = new File(rootPath).list();
		if (corpusDirs == null)
			throw new IllegalArgumentException("Not a readable directory: "+rootPath);

		for (String corpusDir : corpusDirs)	// e.g., ebc, wsj
		{
			String corpusPath = rootPath   + File.separator + corpusDir;
			String propPath   = corpusPath + File.separator + "prop";
			String parsePath  = corpusPath + File.separator + "parse";
			File   propDir    = new File(propPath);

			if (!propDir.isDirectory())	continue;
			System.out.println(propPath);

			String[] propFiles = propDir.list();
			if (propFiles == null)
			{
				System.err.println("Cannot list directory: "+propPath);
				continue;
			}

			int count = 0;

			for (String propFile : propFiles)
			{
				if (!propFile.endsWith(PROP_EXT))	continue;
				if (++count % 100 == 0)	System.out.print(".");

				// Build the parse path explicitly rather than rewriting the prop path with a regex,
				// which would also rewrite any "/prop/" that happens to occur inside rootPath.
				String baseName = propFile.substring(0, propFile.length() - PROP_EXT.length());
				readParses(parsePath + File.separator + baseName + ".parse");
				readProps (propPath  + File.separator + propFile);
			}

			System.out.println();
		}
	}

	/** Loads all trees of {@code parseFile} into {@link #ls_trees}, replacing the previous file's trees. */
	void readParses(String parseFile)
	{
		TBReader reader = new TBReader(parseFile);
		TBTree   tree;

		ls_trees = new ArrayList<TBTree>();

		while ((tree = reader.nextTree()) != null)
			ls_trees.add(tree);

		ls_trees.trimToSize();
	}

	/**
	 * Feeds the verb instances of {@code propFile} to the enabled analysis.
	 *
	 * <p>Only instances whose type ends in {@code "-v"} and that have more than one argument are used.
	 * The {@code "-v"} suffix is stripped from {@code instance.type} in place, and {@link #tb_tree} is
	 * set to the instance's tree (so {@link #readParses(String)} must have run for the matching file).
	 */
	void readProps(String propFile)
	{
		PBReader   reader = new PBReader(propFile);
		PBInstance instance;

		while ((instance = reader.nextInstance()) != null)
		{
			if (!instance.type.endsWith("-v"))	continue;
			if (instance.getArgs().size() <= 1)	continue;
			instance.type = instance.type.substring(0, instance.type.length()-2);
			tb_tree = ls_trees.get(instance.treeIndex);

		//	processNumberedAdjunct(instance);
		//	processVerbPrepPMI(instance);
			processRequiredArgument(instance);
		}
	}

	// ----------------------------- NumberedAdjunct -----------------------------

	/** Allocates one preposition-count map per numbered argument ARG0..ARG5. */
	void initNumberedAdjunct()
	{
		ls_numberedAdjuncts = new ArrayList<ObjectIntOpenHashMap<String>>();

		for (int i=0; i<=5; i++)
			ls_numberedAdjuncts.add(new ObjectIntOpenHashMap<String>());
	}

	/**
	 * Counts, for each numbered argument of {@code instance}, the prepositions heading its
	 * {@code PP} locations, and increments that argument's {@link #TOTAL}.
	 */
	void processNumberedAdjunct(PBInstance instance)
	{
		for (PBArg arg : instance.getArgs())
		{
			if (!isNumberedArg(arg.label))	continue;

			int argNum = Integer.parseInt(arg.label.substring(3,4));
			// Maps exist only for ARG0..ARG5; skip rarer labels instead of failing with IndexOutOfBounds.
			if (argNum >= ls_numberedAdjuncts.size())	continue;
			ObjectIntOpenHashMap<String> map = ls_numberedAdjuncts.get(argNum);

			for (PBLoc loc : arg.getLocs())
			{
				String prep = getPreposition(loc);
				if (prep != null)	increment(map, prep);
			}

			increment(map, TOTAL);
		}
	}

	/**
	 * Writes, for each ARGi, a header line {@code ARGi<TAB>total} followed by one line per
	 * preposition: form, count, and count as a percentage of the total.
	 */
	void printNumberedAdjunct(String outputFile)
	{
		PrintStream fout = IOUtil.createPrintFileStream(outputFile);

		try
		{
			for (int i=0; i<ls_numberedAdjuncts.size(); i++)
			{
				ObjectIntOpenHashMap<String>       map  = ls_numberedAdjuncts.get(i);
				ArrayList<JObjectIntTuple<String>> list = new ArrayList<JObjectIntTuple<String>>();

				for (ObjectCursor<String> cur : map.keys())
				{
					String key = cur.value;
					if (!key.equals(TOTAL))
						list.add(new JObjectIntTuple<String>(key, map.get(key)));
				}

				Collections.sort(list);

				int total = map.get(TOTAL);
				fout.println("ARG"+i+"\t"+total);

				for (JObjectIntTuple<String> tup : list)
					fout.println(tup.object+"\t"+tup.integer+"\t"+(double)tup.integer*100/total);
			}
		}
		finally
		{
			fout.flush();	fout.close();
		}
	}

	// ----------------------------- VerbPrepPMI -----------------------------

	/** Allocates the count maps used by the VerbPrepPMI analysis. */
	void initVerbPrepPMI()
	{
		m_verbs      = new ObjectIntOpenHashMap<String>();
		m_preps      = new ObjectIntOpenHashMap<String>();
		m_verbPreps  = new ObjectIntOpenHashMap<String>();
		m_verbPrepNs = new ObjectIntOpenHashMap<String>();
		m_verbPrepMs = new ObjectIntOpenHashMap<String>();
		m_verbArgs   = new ObjectIntOpenHashMap<String>();
		m_verbArgNs  = new ObjectIntOpenHashMap<String>();
		m_verbArgMs  = new ObjectIntOpenHashMap<String>();
	}

	/**
	 * Counts the verb and each of its {@code ARG*} arguments (excluding ARGM-MOD/NEG).
	 * For every {@code PP} location of an argument, it counts the verb/preposition pair, split
	 * by whether the argument is numbered.
	 */
	void processVerbPrepPMI(PBInstance instance)
	{
		boolean isPassive = VoiceDetector.getPassive(tb_tree.getNode(instance.predicateId, 0)) > 0;
		String  vLemma    = instance.type;

		increment(m_verbs, vLemma);
		increment(m_verbs, TOTAL);

		for (PBArg arg : instance.getArgs())
		{
			if (!arg.label.startsWith("ARG") || isModOrNeg(arg.label))	continue;

			boolean numbered = isNumberedArg(arg.label);
			increment(numbered ? m_verbArgNs : m_verbArgMs, vLemma);
			increment(m_verbArgs, vLemma);
			increment(m_verbArgs, TOTAL);

			for (PBLoc loc : arg.getLocs())
			{
				String pLemma = getPreposition(loc);
				if (pLemma == null)	continue;
				// Skip the "by" of an ARG0 in a passive clause; presumably this is the agent's
				// by-phrase rather than a preposition chosen by the verb.
				if (isPassive && arg.label.equals("ARG0") && pLemma.equals("by"))	continue;

				String key = vLemma+"_"+pLemma;
				increment(m_verbPreps, key);
				increment(numbered ? m_verbPrepNs : m_verbPrepMs, key);
				increment(m_preps, pLemma);
				increment(m_preps, TOTAL);
			}
		}
	}

	/**
	 * Removes every verb/preposition pair seen at most once from the VerbPrepPMI counts.
	 *
	 * <p>The pair's numbered/non-numbered shares are subtracted from the per-verb numbered and
	 * non-numbered argument counts. Its total is subtracted from the per-verb argument count, the
	 * verb count, the preposition count, and their totals.
	 * NOTE(review): {@link #m_verbs} counts instances, not pairs. Subtracting the pair count from it
	 * keeps the original behavior; confirm that this is intended.
	 */
	void trimVerbPrep()
	{
		// Snapshot the keys so that updating m_verbPreps cannot interfere with iterating over it.
		ArrayList<String> keys = new ArrayList<String>();
		for (ObjectCursor<String> cur : m_verbPreps.keys())
			keys.add(cur.value);

		for (String key : keys)
		{
			int value = m_verbPreps.get(key);
			if (value > 1)	continue;

			String[] tmp       = splitKey(key);
			int      nNumbered = m_verbPrepNs.get(key);
			int      nOther    = m_verbPrepMs.get(key);

			decrement(m_verbPreps , key   , value);
			decrement(m_verbPrepNs, key   , nNumbered);
			decrement(m_verbPrepMs, key   , nOther);
			// The per-verb argument maps are keyed by verb lemma, not by verb_prep.
			decrement(m_verbArgs  , tmp[0], value);
			decrement(m_verbArgs  , TOTAL , value);
			decrement(m_verbArgNs , tmp[0], nNumbered);
			decrement(m_verbArgMs , tmp[0], nOther);
			decrement(m_verbs     , tmp[0], value);
			decrement(m_verbs     , TOTAL , value);
			decrement(m_preps     , tmp[1], value);
			decrement(m_preps     , TOTAL , value);
		}
	}

	/**
	 * Writes one line per verb/preposition pair: key, {@code pnv}, {@code pmv} and
	 * {@code log(pnv / pmv)}, in the order given by sorting on that score.
	 *
	 * <p>{@code pnv} is the pair's count in numbered arguments divided by the verb's numbered-argument
	 * count. {@code pmv} is the same ratio for non-numbered arguments. Both are smoothed by a small
	 * constant and fall back to that constant when the verb has no such arguments.
	 * Pairs zeroed by {@link #trimVerbPrep()} are skipped.
	 */
	void printVerbPrepPMI(String outputFile)
	{
		final double smooth = 0.000001;
		PrintStream fout = IOUtil.createPrintFileStream(outputFile);

		try
		{
			ArrayList<JObjectDoubleTuple<String>> list = new ArrayList<JObjectDoubleTuple<String>>();

			for (ObjectCursor<String> cur : m_verbPreps.keys())
			{
				String key = cur.value;
				if (m_verbPreps.get(key) == 0)	continue;

				String verb = splitKey(key)[0];
				double pnv  = smoothedRatio(m_verbPrepNs.get(key), m_verbArgNs.get(verb), smooth);
				double pmv  = smoothedRatio(m_verbPrepMs.get(key), m_verbArgMs.get(verb), smooth);

				list.add(new JObjectDoubleTuple<String>(key, Math.log(pnv / pmv)));
			}

			Collections.sort(list);

			for (JObjectDoubleTuple<String> tup : list)
			{
				String key  = tup.object;
				String verb = splitKey(key)[0];
				double pnv  = smoothedRatio(m_verbPrepNs.get(key), m_verbArgNs.get(verb), smooth);
				double pmv  = smoothedRatio(m_verbPrepMs.get(key), m_verbArgMs.get(verb), smooth);

				fout.println(key+"\t"+pnv+"\t"+pmv+"\t"+tup.value);
			}
		}
		finally
		{
			fout.flush();	fout.close();
		}
	}

	/** Returns {@code smooth + numerator/denominator}, or just {@code smooth} when {@code denominator} is 0. */
	private double smoothedRatio(int numerator, int denominator, double smooth)
	{
		return (denominator == 0) ? smooth : smooth + (double)numerator / denominator;
	}

// ----------------------------- RequiredArgument -----------------------------

	/** Allocates the count maps used by the RequiredArgument analysis. */
	void initRequiredArgument()
	{
		m_verbs     = new ObjectIntOpenHashMap<String>();
		m_preps     = new ObjectIntOpenHashMap<String>();
		m_verbArgNs = new ObjectIntOpenHashMap<String>();
		m_verbArgMs = new ObjectIntOpenHashMap<String>();
	}

	/**
	 * Counts {@code verb_label} occurrences for a non-passive predicate whose sentence group is
	 * {@code SQ}; other instances are ignored.
	 *
	 * <p>Numbered labels go to {@link #m_verbArgNs}. Other {@code ARG*} labels, excluding
	 * ARGM-MOD/NEG, go to {@link #m_verbArgMs}, and their label is also counted in {@link #m_preps}.
	 */
	void processRequiredArgument(PBInstance instance)
	{
		TBNode  predicate = tb_tree.getNode(instance.predicateId, 0);
		boolean isPassive = VoiceDetector.getPassive(predicate) > 0;
		String  sentence  = predicate.getSentenceGroup();
		if (isPassive || !"SQ".equals(sentence))	return;
		n_count++;

		String vLemma = instance.type;

		increment(m_verbs, vLemma);
		increment(m_verbs, TOTAL);

		for (PBArg arg : instance.getArgs())
		{
			if (!arg.label.startsWith("ARG") || isModOrNeg(arg.label))	continue;

			String key = vLemma+"_"+arg.label;

			if (isNumberedArg(arg.label))
				increment(m_verbArgNs, key);
			else
			{
				increment(m_verbArgMs, key);
				increment(m_preps, arg.label);
				increment(m_preps, TOTAL);
			}
		}
	}

	/**
	 * Prints {@link #n_count} to stdout, then writes one line per verb, sorted by lemma.
	 * Each line holds the lemma, its instance count, and its sorted label scores.
	 *
	 * <p>A numbered label scores {@code 100 * count(verb, label) / count(verb)}. A non-numbered label
	 * scores {@code log(P(label|verb) / P(label)) / -(log P(label|verb) + log P(verb))}, i.e. PMI
	 * normalized by {@code -log P(verb, label)}. Only positive scores are kept.
	 */
	void printRequiredArgument(String outputFile)
	{
		final double thresh = 0;
		PrintStream fout = IOUtil.createPrintFileStream(outputFile);
		System.out.println(n_count);

		try
		{
			HashMap<String, ArrayList<JObjectDoubleTuple<String>>> map = new HashMap<String, ArrayList<JObjectDoubleTuple<String>>>();

			for (ObjectCursor<String> cur : m_verbs.keys())
			{
				if (!cur.value.equals(TOTAL))
					map.put(cur.value, new ArrayList<JObjectDoubleTuple<String>>());
			}

			for (ObjectCursor<String> cur : m_verbArgNs.keys())
			{
				String[] tmp   = splitKey(cur.value);
				double   score = (double)m_verbArgNs.get(cur.value) * 100 / m_verbs.get(tmp[0]);

				if (score > thresh)	map.get(tmp[0]).add(new JObjectDoubleTuple<String>(tmp[1], score));
			}

			int nVerbTotal  = m_verbs.get(TOTAL);
			int nLabelTotal = m_preps.get(TOTAL);

			for (ObjectCursor<String> cur : m_verbArgMs.keys())
			{
				String[] tmp   = splitKey(cur.value);
				int      nVerb = m_verbs.get(tmp[0]);
				double   pv    = (double)m_verbArgMs.get(cur.value) / nVerb;
				double   p     = (double)m_preps.get(tmp[1]) / nLabelTotal;
				double   v     = (double)nVerb / nVerbTotal;
				double   npmi  = getPMI(pv, p) / -(Math.log(pv) + Math.log(v));

				if (npmi > 0)	map.get(tmp[0]).add(new JObjectDoubleTuple<String>(tmp[1], npmi));
			}

			// Sort verbs so the report is reproducible regardless of HashMap iteration order.
			ArrayList<String> verbs = new ArrayList<String>(map.keySet());
			Collections.sort(verbs);

			for (String verb : verbs)
			{
				ArrayList<JObjectDoubleTuple<String>> list = map.get(verb);
				Collections.sort(list);

				StringBuilder build = new StringBuilder(verb);
				build.append("\t").append(m_verbs.get(verb));

				for (JObjectDoubleTuple<String> tup : list)
					build.append("\t").append(tup.toString());

				fout.println(build.toString());
			}
		}
		finally
		{
			fout.flush();	fout.close();
		}
	}

	// ----------------------------- Helpers -----------------------------

	/**
	 * Returns the form, lower-cased with {@link Locale#ROOT}, of the first {@code IN} child of the
	 * node at {@code loc} in {@link #tb_tree}. Returns {@code null} if that node is an empty category
	 * (per {@code isEmptyCategoryRec}), is not a {@code PP}, or has no {@code IN} child.
	 */
	private String getPreposition(PBLoc loc)
	{
		TBNode node = tb_tree.getNode(loc.terminalId, loc.height);
		if (node.isEmptyCategoryRec() || !node.isPos("PP"))	return null;

		for (TBNode child : node.getChildren())
		{
			if (child.isPos("IN"))
				return child.form.toLowerCase(Locale.ROOT);
		}

		return null;
	}

	/** Returns {@code true} if {@code label} is exactly {@code ARG} followed by one digit. */
	private boolean isNumberedArg(String label)
	{
		return NUMBERED_ARG.matcher(label).matches();
	}

	/** Returns {@code true} if {@code label} is {@code ARGM-MOD} or {@code ARGM-NEG}. */
	private boolean isModOrNeg(String label)
	{
		return MOD_OR_NEG.matcher(label).matches();
	}

	/**
	 * Splits a {@code <verb>_<item>} key at its last underscore, keeping a verb lemma that itself
	 * contains underscores intact. Assumes the key contains an underscore and that the item part
	 * (a preposition or an argument label) presumably never does.
	 */
	private String[] splitKey(String key)
	{
		int idx = key.lastIndexOf('_');
		return new String[]{key.substring(0, idx), key.substring(idx+1)};
	}

	/** Returns the pointwise mutual information {@code log(pxy / px)} (natural log). */
	double getPMI(double pxy, double px)
	{
		return Math.log(pxy / px);
	}

	/** Adds 1 to the count of {@code key} (missing keys read as the map's default value). */
	void increment(ObjectIntOpenHashMap<String> map, String key)
	{
		map.put(key, map.get(key)+1);
	}

	/** Subtracts {@code dec} from the count of {@code key} (missing keys read as the map's default value). */
	void decrement(ObjectIntOpenHashMap<String> map, String key, int dec)
	{
		map.put(key, map.get(key)-dec);
	}

	/** Returns the base-2 logarithm of {@code d}. */
	double log2(double d)
	{
		return Math.log(d) / Math.log(2);
	}

	/** Returns {@code true} for {@code ARGM*} labels other than {@code ARGM-NEG} and {@code ARGM-MOD}. */
	boolean isArgM(String label)
	{
		return label.startsWith("ARGM") && !label.equals("ARGM-NEG") && !label.equals("ARGM-MOD");
	}

	/**
	 * Command-line entry point.
	 *
	 * @param args {@code args[0]} = corpus root directory, {@code args[1]} = output file
	 */
	public static void main(String[] args)
	{
		if (args.length < 2)
		{
			System.err.println("Usage: java clear.experiment.AnalyzePBArgs <root path> <output file>");
			System.exit(1);
		}

		new AnalyzePBArgs(args[0], args[1]);
	}
}
